{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b510dd3f",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "# Bidirectional Encoder Representations from Transformers (BERT)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ffa5c8df",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:51.331685Z",
     "iopub.status.busy": "2023-08-18T19:31:51.331049Z",
     "iopub.status.idle": "2023-08-18T19:31:54.812815Z",
     "shell.execute_reply": "2023-08-18T19:31:54.811897Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe8bcf59",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Input Representation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3c018a43",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:54.816440Z",
     "iopub.status.busy": "2023-08-18T19:31:54.816038Z",
     "iopub.status.idle": "2023-08-18T19:31:54.823105Z",
     "shell.execute_reply": "2023-08-18T19:31:54.822123Z"
    },
    "origin_pos": 4,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def get_tokens_and_segments(tokens_a, tokens_b=None):\n",
    "    \"\"\"Get tokens of the BERT input sequence and their segment IDs.\"\"\"\n",
    "    tokens = ['<cls>'] + tokens_a + ['<sep>']\n",
    "    segments = [0] * (len(tokens_a) + 2)\n",
    "    if tokens_b is not None:\n",
    "        tokens += tokens_b + ['<sep>']\n",
    "        segments += [1] * (len(tokens_b) + 1)\n",
    "    return tokens, segments"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45ba910d",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "`BERTEncoder` class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a90baa53",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:54.828159Z",
     "iopub.status.busy": "2023-08-18T19:31:54.827512Z",
     "iopub.status.idle": "2023-08-18T19:31:54.835645Z",
     "shell.execute_reply": "2023-08-18T19:31:54.834698Z"
    },
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class BERTEncoder(nn.Module):\n",
    "    \"\"\"BERT encoder.\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,\n",
    "                 num_blks, dropout, max_len=1000, **kwargs):\n",
    "        super(BERTEncoder, self).__init__(**kwargs)\n",
    "        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)\n",
    "        self.segment_embedding = nn.Embedding(2, num_hiddens)\n",
    "        self.blks = nn.Sequential()\n",
    "        for i in range(num_blks):\n",
    "            self.blks.add_module(f\"{i}\", d2l.TransformerEncoderBlock(\n",
    "                num_hiddens, ffn_num_hiddens, num_heads, dropout, True))\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,\n",
    "                                                      num_hiddens))\n",
    "\n",
    "    def forward(self, tokens, segments, valid_lens):\n",
    "        X = self.token_embedding(tokens) + self.segment_embedding(segments)\n",
    "        X = X + self.pos_embedding[:, :X.shape[1], :]\n",
    "        for blk in self.blks:\n",
    "            X = blk(X, valid_lens)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f8d4e1d",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Inference of `BERTEncoder`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "56903d71",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:54.990720Z",
     "iopub.status.busy": "2023-08-18T19:31:54.989911Z",
     "iopub.status.idle": "2023-08-18T19:31:55.148261Z",
     "shell.execute_reply": "2023-08-18T19:31:55.147176Z"
    },
    "origin_pos": 13,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 8, 768])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4\n",
    "ffn_num_input, num_blks, dropout = 768, 2, 0.2\n",
    "encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens, num_heads,\n",
    "                      num_blks, dropout)\n",
    "\n",
    "tokens = torch.randint(0, vocab_size, (2, 8))\n",
    "segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])\n",
    "encoded_X = encoder(tokens, segments, None)\n",
    "encoded_X.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7700375b",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Masked Language Modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a8614e46",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.154295Z",
     "iopub.status.busy": "2023-08-18T19:31:55.153400Z",
     "iopub.status.idle": "2023-08-18T19:31:55.162155Z",
     "shell.execute_reply": "2023-08-18T19:31:55.161271Z"
    },
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class MaskLM(nn.Module):\n",
    "    \"\"\"The masked language model task of BERT.\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens, **kwargs):\n",
    "        super(MaskLM, self).__init__(**kwargs)\n",
    "        self.mlp = nn.Sequential(nn.LazyLinear(num_hiddens),\n",
    "                                 nn.ReLU(),\n",
    "                                 nn.LayerNorm(num_hiddens),\n",
    "                                 nn.LazyLinear(vocab_size))\n",
    "\n",
    "    def forward(self, X, pred_positions):\n",
    "        num_pred_positions = pred_positions.shape[1]\n",
    "        pred_positions = pred_positions.reshape(-1)\n",
    "        batch_size = X.shape[0]\n",
    "        batch_idx = torch.arange(0, batch_size)\n",
    "        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)\n",
    "        masked_X = X[batch_idx, pred_positions]\n",
    "        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))\n",
    "        mlm_Y_hat = self.mlp(masked_X)\n",
    "        return mlm_Y_hat"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ceb49f61",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The forward inference of `MaskLM`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6b3fc7d6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.166188Z",
     "iopub.status.busy": "2023-08-18T19:31:55.165836Z",
     "iopub.status.idle": "2023-08-18T19:31:55.273706Z",
     "shell.execute_reply": "2023-08-18T19:31:55.272680Z"
    },
    "origin_pos": 19,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3, 10000])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mlm = MaskLM(vocab_size, num_hiddens)\n",
    "mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])\n",
    "mlm_Y_hat = mlm(encoded_X, mlm_positions)\n",
    "mlm_Y_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8d85768b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.278830Z",
     "iopub.status.busy": "2023-08-18T19:31:55.278109Z",
     "iopub.status.idle": "2023-08-18T19:31:55.288546Z",
     "shell.execute_reply": "2023-08-18T19:31:55.287485Z"
    },
    "origin_pos": 22,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])\n",
    "loss = nn.CrossEntropyLoss(reduction='none')\n",
    "mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))\n",
    "mlm_l.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf2bde58",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Next Sentence Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c1951876",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.292806Z",
     "iopub.status.busy": "2023-08-18T19:31:55.291904Z",
     "iopub.status.idle": "2023-08-18T19:31:55.298328Z",
     "shell.execute_reply": "2023-08-18T19:31:55.297464Z"
    },
    "origin_pos": 25,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class NextSentencePred(nn.Module):\n",
    "    \"\"\"The next sentence prediction task of BERT.\"\"\"\n",
    "    def __init__(self, **kwargs):\n",
    "        super(NextSentencePred, self).__init__(**kwargs)\n",
    "        self.output = nn.LazyLinear(2)\n",
    "\n",
    "    def forward(self, X):\n",
    "        return self.output(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57961ff1",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The forward inference of an `NextSentencePred`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "aba0cce5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.302539Z",
     "iopub.status.busy": "2023-08-18T19:31:55.301869Z",
     "iopub.status.idle": "2023-08-18T19:31:55.310590Z",
     "shell.execute_reply": "2023-08-18T19:31:55.309427Z"
    },
    "origin_pos": 28,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoded_X = torch.flatten(encoded_X, start_dim=1)\n",
    "nsp = NextSentencePred()\n",
    "nsp_Y_hat = nsp(encoded_X)\n",
    "nsp_Y_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ba504299",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.314489Z",
     "iopub.status.busy": "2023-08-18T19:31:55.313852Z",
     "iopub.status.idle": "2023-08-18T19:31:55.321256Z",
     "shell.execute_reply": "2023-08-18T19:31:55.320193Z"
    },
    "origin_pos": 31,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nsp_y = torch.tensor([0, 1])\n",
    "nsp_l = loss(nsp_Y_hat, nsp_y)\n",
    "nsp_l.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fd4cc06",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Putting It All Together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "73c331cd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.325106Z",
     "iopub.status.busy": "2023-08-18T19:31:55.324530Z",
     "iopub.status.idle": "2023-08-18T19:31:55.333301Z",
     "shell.execute_reply": "2023-08-18T19:31:55.332018Z"
    },
    "origin_pos": 34,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class BERTModel(nn.Module):\n",
    "    \"\"\"The BERT model.\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,\n",
    "                 num_heads, num_blks, dropout, max_len=1000):\n",
    "        super(BERTModel, self).__init__()\n",
    "        self.encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens,\n",
    "                                   num_heads, num_blks, dropout,\n",
    "                                   max_len=max_len)\n",
    "        self.hidden = nn.Sequential(nn.LazyLinear(num_hiddens),\n",
    "                                    nn.Tanh())\n",
    "        self.mlm = MaskLM(vocab_size, num_hiddens)\n",
    "        self.nsp = NextSentencePred()\n",
    "\n",
    "    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):\n",
    "        encoded_X = self.encoder(tokens, segments, valid_lens)\n",
    "        if pred_positions is not None:\n",
    "            mlm_Y_hat = self.mlm(encoded_X, pred_positions)\n",
    "        else:\n",
    "            mlm_Y_hat = None\n",
    "        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))\n",
    "        return encoded_X, mlm_Y_hat, nsp_Y_hat"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "language_info": {
   "name": "python"
  },
  "required_libs": [],
  "rise": {
   "autolaunch": true,
   "enable_chalkboard": true,
   "overlay": "<div class='my-top-right'><img height=80px src='http://d2l.ai/_static/logo-with-text.png'/></div><div class='my-top-left'></div>",
   "scroll": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}